import copy
import math
import os
import re
import shutil
import random

import numpy as np
from pydfnworks.general.paths import define_paths
from pydfnworks.general.dfnworks import create_dfn
from pydfnworks.dfnGen.gen_input import check_input
from pydfnworks.dfnGen.generator import make_working_directory, create_network
from pydfnworks.dfnGen.mesh_dfn import mesh_network
from tough_script import dist, write_string


def loadDataSet(fileName, splitChar):
    """Read a delimited text file into a list of rows.

    Lines that start with '#' are skipped. Every other line is stripped and
    split on ``splitChar``. If all fields parse as floats the row is stored as
    a list of floats; otherwise the raw string fields are stored unchanged.

    Args:
        fileName: path of the text file to read.
        splitChar: field separator passed to ``str.split``.

    Returns:
        list of rows (each a list of floats or a list of strings).
    """
    dataSet = []
    with open(fileName) as fr:
        for line in fr:
            if line.startswith('#'):
                continue
            fields = line.strip().split(splitChar)
            try:
                dataSet.append([float(field) for field in fields])
            except ValueError:  # non-numeric row: keep the raw strings
                dataSet.append(fields)
    return dataSet


def writeDataSet(fileName, lists, splitChar):
    """Append one delimited row to a text file.

    Each item of ``lists`` is converted with ``str`` and the items are joined
    with ``splitChar``; the row is terminated by a newline. An empty ``lists``
    writes nothing (the file is still created if missing).

    Args:
        fileName: path of the file, opened in append mode.
        lists: sequence of items forming the row.
        splitChar: separator written between items.
    """
    with open(fileName, 'a') as files:  # closes the handle even if a write fails
        if len(lists) > 0:
            files.write(splitChar.join(str(item) for item in lists) + '\n')


def proj(geo_p):  # calculate analytical parameters of a fracture
    """Convert geological fracture parameters to plane coefficients.

    Args:
        geo_p: sequence ``[x, y, z, ang3, ang4, ...]`` where (x, y, z) is a
            point on the fracture and ang3/ang4 are angles in degrees.
            ``ang4`` is the angle between the plane normal and the +z axis;
            ``ang3`` is the horizontal direction of the normal, measured from
            +y toward +x.

    Returns:
        ``[a, b, c, d]`` such that ``a*x + b*y + c*z + d == 0`` on the plane;
        (a, b, c) is a unit normal.
    """
    phi = math.radians(geo_p[3])
    theta = math.radians(geo_p[4])
    horizontal = math.sin(theta)
    a = horizontal * math.sin(phi)
    b = horizontal * math.cos(phi)
    c = math.cos(theta)
    offset = -(a * geo_p[0] + b * geo_p[1] + c * geo_p[2])
    return [a, b, c, offset]


def distances(point, pa, pit=False):
    """Distance from a point to a plane, or the point's orthogonal projection.

    Args:
        point: 3 coordinates.
        pa: plane coefficients ``[a, b, c, d]`` (plane ``a*x + b*y + c*z + d = 0``).
        pit: when True, return the projection of ``point`` onto the plane
            instead of the distance.

    Returns:
        The unsigned distance, or a 3-element list with the projected point.
    """
    normal = pa[:3]
    signed = np.dot(normal, point) + pa[3]
    norm_sq = np.dot(normal, normal)
    if not pit:
        return abs(signed) / np.sqrt(norm_sq)
    step = -signed / norm_sq  # move along the normal until the plane equation is satisfied
    return [point[axis] + normal[axis] * step for axis in range(3)]


def rotationMatrix(source, sink, eps=1e-7):
    """Return the 3x3 rotation matrix that turns the direction of ``source`` onto ``sink``.

    Both vectors are normalised, then Rodrigues' formula
    ``R = I + [v]x + [v]x^2 * (1 - cos) / |v|^2`` with ``v = u x w`` is applied.

    Args:
        source: 3-component vector (any non-zero length).
        sink: 3-component vector (any non-zero length).
        eps: tolerance on each component of ``u x w`` below which the vectors
            are treated as parallel or anti-parallel.

    Returns:
        numpy.ndarray of shape (3, 3):
        - the identity when the vectors are parallel, or when either has zero
          length (no direction is defined);
        - a 180-degree rotation about an axis perpendicular to ``source``
          when they are anti-parallel;
        - otherwise the minimal rotation carrying ``source`` to ``sink``.
    """
    u = np.asarray(source, dtype=float)
    w = np.asarray(sink, dtype=float)
    nu, nw = np.linalg.norm(u), np.linalg.norm(w)
    R = np.eye(3)
    if nu == 0 or nw == 0:
        return R
    u, w = u / nu, w / nw
    xprod, cos = np.cross(u, w), np.dot(u, w)
    if np.all(np.abs(xprod) <= eps):
        if cos > 0:
            return R  # already aligned
        # Anti-parallel: any axis perpendicular to u works; cross u with the
        # coordinate axis it is least aligned with to keep the result well-conditioned.
        helper = np.eye(3)[np.argmin(np.abs(u))]
        axis = np.cross(u, helper)
        axis = axis / np.linalg.norm(axis)
        return 2.0 * np.outer(axis, axis) - R  # rotation by pi about `axis`
    v = np.array([[0, -xprod[2], xprod[1]], [xprod[2], 0, -xprod[0]], [-xprod[1], xprod[0], 0]])
    # |u x w|^2 is sin^2 of the angle for unit vectors
    return R + v + np.dot(v, v) * ((1.0 - cos) / np.dot(xprod, xprod))


def subransac(Data2, numfractures, sigma, dips, angles, resects):
    """Place random fractures through remaining points, RANSAC-style.

    Each iteration picks a random point of ``Data2`` as the fracture centre,
    draws the two orientation angles from a randomly chosen range group, and
    removes every point within ``sigma`` of the resulting plane. Iteration
    stops when ``Data2`` is empty or ``len(resects)`` plus the number of new
    fractures reaches ``numfractures``.

    Args:
        Data2: list of 3-coordinate points. Modified in place: fitted points are removed.
        numfractures: total fracture budget (including ``resects``).
        sigma: point-to-plane distance threshold for a point to be fitted.
        dips, angles: parallel lists of ``[low, high]`` integer ranges; one
            group index is drawn per fracture and used for both.
        resects: already-placed fractures; only its length is used.

    Returns:
        ``(center_a, anal, remaining)``. ``center_a`` holds the lists
        ``[x, y, z, dip, angle]`` for the new fractures. ``anal`` holds their
        plane coefficients from ``proj``. ``remaining`` is the number of
        unfitted points left in ``Data2``.
    """
    center_a, anal = [], []  # fractures' geological and analytical parameters
    lines = 0
    while len(Data2) > 0 and lines + len(resects) < numfractures:
        # keep the exact random-call sequence so seeded runs stay reproducible
        rcenter = Data2[random.sample(range(len(Data2)), 1)[0]][:3]  # slice -> new list, Data2 untouched
        gb = random.randint(0, len(dips) - 1)
        dip, angle = random.randint(dips[gb][0], dips[gb][1]), random.randint(angles[gb][0], angles[gb][1])
        rcenter.extend([dip, angle])
        pt = proj(rcenter)
        anal.append(pt)
        # single pass instead of repeated list.remove(); in place so the caller's list shrinks as before
        Data2[:] = [p for p in Data2 if not distances(p, pt) <= sigma]
        center_a.append(rcenter)
        lines += 1
    return center_a, anal, len(Data2)


def dbscan(Data, Eps, MinPts):  # for size determination, keep the largest cluster
    """Density-based clustering that keeps only the largest cluster found.

    Points are visited in random order (``random.choice``). A point whose
    Eps-neighbourhood (computed with ``dist``, imported from ``tough_script``)
    holds at least ``MinPts`` indices seeds a new cluster. The cluster grows by
    appending the neighbourhoods of unvisited neighbours that are themselves
    dense. Results therefore depend on the random visiting order.

    Args:
        Data: list of points. Membership tests compare point values with
            ``==``, so equal points are treated as the same point
            (presumably lists of coordinates).
        Eps: neighbourhood radius compared against ``dist``.
        MinPts: minimum neighbourhood size for a dense point (the point itself
            is included if ``dist(x, x) <= Eps``, presumably always true).

    Returns:
        ``(clustered, reclus)``:
        - ``clustered`` holds the point values of the largest cluster, with
          duplicates removed.
        - ``reclus`` holds the points not kept: noise and members of smaller
          or superseded clusters.
    """
    unvisited, visited = [i for i in range(len(Data))], []
    # C: per-index cluster label (-1 = noise); computed but not returned
    C, clustered, passed, reclus = [-1 for _ in range(len(Data))], [], [], []
    k, ma = -1, 0  # k: current cluster label, ma: size of the largest cluster so far
    while len(unvisited) > 0:
        subcluster, N = [], []
        p = random.choice(unvisited)
        unvisited.remove(p)
        visited.append(p)
        for i in range(len(Data)):
            if dist(Data[i], Data[p]) <= Eps:
                N.append(i)
        if len(N) >= MinPts:
            k += 1
            C[p] = k
            # N grows while being iterated: this is the cluster expansion step
            for pi in N:
                if pi in unvisited:
                    unvisited.remove(pi)
                    visited.append(pi)
                    M = []
                    for j in range(len(Data)):
                        if dist(Data[j], Data[pi]) <= Eps:
                            M.append(j)
                    if len(M) >= MinPts:
                        for t in M:
                            if t not in N:
                                N.append(t)
                if C[pi] == -1:
                    C[pi] = k
            for jc in N:
                if Data[jc] not in subcluster:
                    subcluster.append(Data[jc])
            # a smaller cluster is discarded; a larger one replaces `clustered`
            # and the previous largest cluster is discarded instead
            pass_clus = subcluster
            if len(subcluster) > ma:
                ma = len(subcluster)
                pass_clus = copy.deepcopy(clustered)
                clustered = subcluster
            passed.extend(pass_clus)
            reclus.extend([ct for ct in passed if ct not in reclus])  # eliminated points
        else:
            C[p] = -1
            if Data[p] not in reclus:
                reclus.append(Data[p])
    # A point first marked as noise can later join the kept cluster as a border point;
    # the size mismatch detects such overlap and removes those points from reclus.
    if len(clustered) + len(reclus) != len(Data):  # verify and delete mistakenly eliminated points
        reclus = [jtc for jtc in reclus if jtc not in clustered]
    return clustered, reclus


def singfracgen(num, clc, data, sigma, dips, dip_angles, initial, ana, distg, iters):
    """Run ``iters`` random subransac trials and keep the best-scoring fracture set.

    The score of a trial is
    ``(sum of distances from every point of clc to every new plane + distg)
    / (number of new + initial fractures) / (number of fitted points)``;
    lower is better. Trials whose score has a zero denominator are skipped.

    Args:
        num: total fracture budget passed to subransac.
        clc: all points used for scoring.
        data: points still available for fitting. Not modified; each trial
            works on its own deep copy.
        sigma, dips, dip_angles: forwarded to subransac.
        initial: geological parameters of pre-existing fractures.
        ana: analytical (plane) parameters of pre-existing fractures.
        distg: distance sum already accumulated for the pre-existing fractures.
        iters: number of random trials.

    Returns:
        ``(best, bes_anal, goal)``. ``best`` is ``initial`` plus the winning
        trial's fractures, and ``bes_anal`` is the matching plane parameters.
        ``goal`` is the best score. The result is ``([], [], inf)`` if no trial
        produced a defined score.
    """
    best, bes_anal = [], []
    goal = np.inf
    for _ in range(iters):
        # subransac removes fitted points in place: give it a private copy so the
        # caller's `data` is never mutated (previously the first trial mutated it)
        trial_data = copy.deepcopy(data)
        center_a, anay, fii = subransac(trial_data, num, sigma, dips, dip_angles, initial)
        dista = 0
        for plane in anay:  # analytical parameters for newly generated fractures
            for pnt in clc:
                dista += distances(pnt, plane)
        n_fracs = len(anay) + len(initial)
        n_fitted = len(clc) - fii
        if n_fracs == 0 or n_fitted == 0:
            continue  # score undefined for this trial
        score = (dista + distg) / n_fracs / n_fitted  # for proposal distribution
        if score < goal:
            goal = score
            # best geological and analytical parameters
            best = copy.deepcopy(initial) + copy.deepcopy(center_a)
            bes_anal = copy.deepcopy(ana) + copy.deepcopy(anay)
    return best, bes_anal, goal


def get_rid_init_f(inif, seis_all, sigma, dip, angle):
    """Turn prior source/sink fracture records into fractures and drop the points they fit.

    Only records whose field 5 is 'p' or 'q' and whose field 6 is a non-zero
    number are used. Placeholders ``'*'`` are filled in place (``inif`` is
    mutated):
    - A ``'*'`` in field 3 or 4 replaces both fields with random integers
      drawn from one randomly chosen ``dip``/``angle`` range group.
    - A ``'*'`` in field 0, 1 or 2 is replaced by a uniform random value within
      the x/y/z extent of ``seis_all``.

    Args:
        inif: list of records (lists of string fields): position x, y, z; two
            orientation angles; mode; rate; enthalpy; extra numeric fields.
        seis_all: list of 3-coordinate points.
        sigma: point-to-plane threshold for removing a point.
        dip, angle: parallel lists of ``[low, high]`` integer ranges.

    Returns:
        A tuple ``(rest, ini_unify, d_init, ini_frac, ss_g)``:
        - ``rest``: copy of ``seis_all`` without points within ``sigma`` of any
          used fracture.
        - ``ini_unify``: plane coefficients for each used fracture.
        - ``d_init``: sum of all point-to-plane distances.
        - ``ini_frac``: first five fields of each used record, as floats.
        - ``ss_g``: source/sink records.
    """
    rest = copy.deepcopy(seis_all)
    d_init, ini_unify, ini_frac, ss_g = 0, [], [], []
    extents = ([m[0] for m in seis_all], [m[1] for m in seis_all], [m[2] for m in seis_all])
    for de in inif:
        if not (de[5] in ('p', 'q') and abs(float(de[6])) > 0):
            continue
        if '*' in (de[3], de[4]):
            group = random.randint(0, len(dip) - 1)
            de[3] = str(random.randint(dip[group][0], dip[group][1]))
            de[4] = str(random.randint(angle[group][0], angle[group][1]))
        for axis, coords in enumerate(extents):  # x, then y, then z
            if de[axis] == '*':
                de[axis] = random.uniform(min(coords), max(coords))
        # source/sink: position, mode (deprecated), rate marker, enthalpy, extra numbers
        position = [float(val) for val in de[:3]]
        ss_g.append(position + [de[5], de[6], de[7]] + [float(val) for val in de[8:]])
        geo = [float(val) for val in de[:5]]
        ini_frac.append(geo)  # fracture geological parameters
        plane = proj(geo)
        ini_unify.append(plane)  # fracture analytical parameters
        for pnt in seis_all:
            gap = distances(pnt, plane)
            d_init += gap
            if gap <= sigma and pnt in rest:
                rest.remove(pnt)
    return rest, ini_unify, d_init, ini_frac, ss_g


def _outside_domain(point, inf, sup):
    """Return True if any of the first three coordinates of ``point`` lies outside [inf, sup]."""
    return any(point[k] < inf[k] or point[k] > sup[k] for k in range(3))


def _reject_by_density(pts, prjs, avery2, points, tpts, sfitp, numrj, rejected, rejected_seq):
    """Filter index-aligned (point, projection) lists with dbscan on the projections.

    Points whose projection dbscan reports as eliminated are appended to
    ``rejected``, their fracture ranking to ``rejected_seq``, and their
    rejection counter in ``numrj`` (indexed like ``tpts``) is incremented.

    Returns:
        ``(kept_pts, kept_prjs)``, still index-aligned with each other.
    """
    _, rej = dbscan(prjs, avery2, points)
    kept_pts, kept_prjs = [], []
    for pt, prj in zip(pts, prjs):  # iterate the inputs, build new outputs: no mutation mid-iteration
        if prj in rej:
            found = tpts.index(pt)
            rejected.append(pt)  # rejected MS
            rejected_seq.append(sfitp[found])  # fractures' sequence list of MS
            numrj[found] += 1
        else:
            kept_pts.append(pt)
            kept_prjs.append(prj)
    return kept_pts, kept_prjs


def _fit_rectangle(real_fitted, normal):
    """Fit a rectangle to points lying on a plane with the given normal.

    The points are rotated so ``normal`` maps to +z, and a least-squares line
    fixes the in-plane orientation. The bounding box of the points in that
    orientation is then rotated back to 3-D.

    Returns:
        ``(vertices, center)``: four 3-coordinate lists in order
        (min,min), (max,min), (max,max), (min,max) of the local frame, and
        the rectangle centre.
    """
    Rt, Rt_inv = rotationMatrix(normal, [0, 0, 1]), rotationMatrix([0, 0, 1], normal)
    local = np.array([np.dot(Rt, p) for p in real_fitted], dtype=float)
    x, y = local[:, 0], local[:, 1]
    zs = local[:, 2].mean()  # average value of z coordinate to avoid numerical error
    n = len(x)
    x_tot, y_tot = x.sum() / n, y.sum() / n
    denom = np.dot(x, x) - x.sum() * x_tot
    if denom == 0:
        theta = 0.0  # all x equal (or a single point): slope undefined, keep axes
    else:
        slope = (np.dot(x, y) - x.sum() * y_tot) / denom
        theta = math.atan(slope)
        # rotate by at most 45 degrees: align the fitted line with the nearer axis
        if slope > 1:
            theta -= math.pi / 2
        elif slope < -1:
            theta += math.pi / 2
    tri_c, tri_s = math.cos(theta), math.sin(theta)
    dx, dy = x - x_tot, y - y_tot
    xj = dx * tri_c + dy * tri_s
    yj = dy * tri_c - dx * tri_s
    corners = [(xj.min(), yj.min()), (xj.max(), yj.min()), (xj.max(), yj.max()), (xj.min(), yj.max())]
    v_x = [cx * tri_c - cy * tri_s + x_tot for cx, cy in corners]  # rotate vertices in x-y plane
    v_y = [cx * tri_s + cy * tri_c + y_tot for cx, cy in corners]
    vertices = [list(np.dot(Rt_inv, [vx, vy, zs])) for vx, vy in zip(v_x, v_y)]
    center = list(np.dot(Rt_inv, [sum(v_x) / 4, sum(v_y) / 4, zs]))
    return vertices, center


def genfrasize(sigma, option, para, avery2, points, inf, sup):
    """Assign points to fracture planes and fit a rectangle to each fracture.

    Steps:
    1. Points within ``sigma`` of their nearest plane are kept. Every plane is
       ranked by distance for each kept point.
    2. Each kept point is projected onto its current candidate plane.
       Projections outside the box [inf, sup], or eliminated by dbscan, are
       rejected and retried on the next plane in their ranking. This repeats
       until no rejects remain or some point has been rejected
       ``4 * len(para)`` times.
    3. For each plane, the largest dbscan cluster of projections gets a
       fitted rectangle. The rectangles are written to
       'define_4_user_ellipses_2.dat' via ``write_string``. Processing stops at
       the first plane with no accepted projections, and the ``nEllipses``
       header counts the rectangles actually written.

    Args:
        sigma: point-to-plane distance threshold.
        option: list of 3-coordinate points.
        para: list of plane coefficients ``[a, b, c, d]`` (presumably from ``proj``).
        avery2: dbscan neighbourhood radius (Eps).
        points: dbscan minimum neighbourhood size (MinPts).
        inf, sup: lower and upper corners of the domain box.

    Returns:
        list of rectangle centres (one per written rectangle).
    """
    # 1. rank fractures for every point; ties keep fracture order (matches a stable sort)
    tpts, sfitp = [], []
    for tj in option:
        ranked = sorted((distances(tj, plane), ij + 1) for ij, plane in enumerate(para))
        if ranked and ranked[0][0] <= 1.0 * sigma:
            tpts.append(tj)
            sfitp.append([frac for _, frac in ranked])  # write fracture according to distances
    print(f'{len(tpts)} out of {len(option)} points fitted')
    numrj = [0] * len(tpts)  # rejects for each point
    limit = 4 * len(para)

    # 2a. first pass: every point on its best fracture.
    # tfit[a] (projections) and tfit2[a] (points) are kept index-aligned.
    brjt, cfra, tfit, tfit2 = [], [], [], []
    for a, plane in enumerate(para):
        afit, afitp = [], []
        for b, pt in enumerate(tpts):
            if sfitp[b][0] != a + 1:
                continue
            sec = distances(pt, plane, True)
            if _outside_domain(sec, inf, sup):  # projection out of the domain
                brjt.append(pt)
                cfra.append(sfitp[b])
                numrj[b] += 1
            else:
                afit.append(pt)
                afitp.append(sec)
        afit, afitp = _reject_by_density(afit, afitp, avery2, points, tpts, sfitp, numrj, brjt, cfra)
        tfit.append(afitp)
        tfit2.append(afit)

    # 2b. retry rejected points on their next-ranked fracture
    while brjt and max(numrj) < limit:
        print(f'{len(brjt)} points remaining, {max(numrj)} out of {4 * len(para)} completed')
        rejected, cfras, new_tfit, new_tfit2 = [], [], [], []
        for aca in range(len(tfit)):
            ugbv, ugb = list(tfit2[aca]), list(tfit[aca])
            for get1, bdb in enumerate(brjt):
                get2 = tpts.index(bdb)
                seq = cfra[get1]
                if seq[numrj[get2] % len(seq)] != aca + 1:
                    continue
                sec = distances(bdb, para[aca], True)
                if _outside_domain(sec, inf, sup):
                    rejected.append(bdb)
                    cfras.append(sfitp[get2])
                    numrj[get2] += 1
                else:
                    ugbv.append(bdb)
                    ugb.append(sec)
            ugbv, ugb = _reject_by_density(ugbv, ugb, avery2, points, tpts, sfitp, numrj, rejected, cfras)
            new_tfit.append(ugb)
            new_tfit2.append(ugbv)
        tfit, tfit2 = new_tfit, new_tfit2  # new classification of accepted points
        brjt, cfra = rejected, cfras
    # default=0: no point fitted at all must not crash
    print(f'finish at {max(numrj, default=0)} out of {4 * len(para)} iterations')

    # 3. fit one rectangle per fracture from its largest projection cluster
    rows, centers = [], []
    for t in range(len(tfit)):
        if not tfit[t]:
            break  # NOTE(review): stops at the first empty fracture; `continue` may be intended
        real_fitted, _ = dbscan(tfit[t], avery2, points)
        if not real_fitted:
            break
        vertices, cen = _fit_rectangle(real_fitted, para[t][:3])
        rows.append('\t'.join('{' + f'{v[0]},{v[1]},{v[2]}' + '}' for v in vertices))
        centers.append(cen)  # solution to center axis and size (equal radius)
    write_f = f'nEllipses: {len(rows)}\nnNodes: 4\nCoordinates:\n' + ''.join(row + '\n' for row in rows)
    write_string('define_4_user_ellipses_2.dat', 'w', write_f)  # file for further LaGriT mesh
    return centers


def truncation_judge(nodes, left, right, eps=1e-7):
    """Report whether a polygon sticks out of an axis-aligned box.

    Args:
        nodes: flat coordinate list ``[x0, y0, z0, x1, y1, z1, ...]``;
            trailing values that do not form a full triple are ignored.
        left, right: lower and upper box corners.
        eps: tolerance added on both sides of the box.

    Returns:
        1 if any vertex lies outside ``[left - eps, right + eps]``, else 0.
    """
    n_full = len(nodes) // 3
    sticks_out = any(
        nodes[3 * vert + axis] > right[axis] + eps or nodes[3 * vert + axis] < left[axis] - eps
        for vert in range(n_full)
        for axis in range(3)
    )
    return 1 if sticks_out else 0  # 0: no need for truncation


def domainTruncation(nodes, left, right, midpt):
    """Shrink a 4-vertex polygon so vertices outside the box move onto its faces.

    Vertices are flagged per axis: 1 if past ``right``, -1 if before ``left``.
    The polygon's two diagonals (vertex 0 with 2, vertex 1 with 3) are handled
    separately. For a diagonal with a flagged end, both ends are scaled about
    ``midpt`` by the same factor ``t <= 1``. ``t`` is the smallest factor that
    brings each flagged coordinate onto its bound.
    NOTE(review): assumes ``midpt`` lies inside the box; otherwise ``t`` can be negative.

    Args:
        nodes: flat list ``[x0, y0, z0, ..., x3, y3, z3]``.
        left, right: lower and upper box corners.
        midpt: scaling centre (the caller passes the vertex mean).

    Returns:
        A deep copy of ``nodes`` with the adjusted coordinates.
    """
    n_vert = len(nodes) // 3
    # (outward normal, point on face, flag value, axis flagged)
    faces = (
        ([0, 0, 1], [0, 0, right[2]], 1, 2),
        ([0, 0, -1], [0, 0, left[2]], -1, 2),
        ([0, 1, 0], [0, right[1], 0], 1, 1),
        ([0, -1, 0], [0, left[1], 0], -1, 1),
        ([1, 0, 0], [right[0], 0, 0], 1, 0),
        ([-1, 0, 0], [left[0], 0, 0], -1, 0),
    )
    flag = [[0, 0, 0] for _ in range(n_vert)]
    for normal, anchor, mark, axis in faces:
        for vert in range(n_vert):
            offset = [nodes[3 * vert + c] - anchor[c] for c in range(3)]
            if np.dot(offset, normal) > 0:  # vertex lies on the outer side of this face
                flag[vert][axis] = mark
    moved = {vert for vert, marks in enumerate(flag) if np.dot(marks, marks) > 0}

    new_pts = copy.deepcopy(nodes)
    for ced in range(2):
        half_diag = [nodes[k + 3 * ced] - midpt[k] for k in range(3)]
        if ced not in moved and ced + 2 not in moved:
            continue
        t = 1.0
        # the opposite vertex (ced + 2) lies along -half_diag from midpt
        for vert, direction in ((ced, half_diag), (ced + 2, [-c for c in half_diag])):
            for sc in range(3):
                side = flag[vert][sc]
                if side == -1:
                    t = min(t, (left[sc] - midpt[sc]) / direction[sc])
                elif side == 1:
                    t = min(t, (right[sc] - midpt[sc]) / direction[sc])
        base = 3 * ced
        for k in range(3):
            new_pts[base + k] = midpt[k] + t * half_diag[k]
            new_pts[base + 6 + k] = midpt[k] - t * half_diag[k]
    return new_pts


def readuserEllByCoord(EllByCoordFileName):
    """Read a user polygon file (three header lines, then one polygon per line).

    Each polygon line holds tab-separated vertices written as ``{x,y,z}``.
    Blank lines after the header are skipped; previously a trailing blank
    line crashed on ``float('')``.

    Args:
        EllByCoordFileName: path of the file.

    Returns:
        ``(vertices, header)``: ``vertices`` is a list of flat float lists
        ``[x0, y0, z0, x1, ...]`` (one per polygon), and ``header`` is the
        first three lines concatenated, with their newlines.
    """
    with open(EllByCoordFileName) as dsa:
        page = dsa.readlines()
    header = page[0] + page[1] + page[2]
    vertices = []
    for polys in page[3:]:  # skip three lines
        polys = polys.strip()
        if not polys:
            continue
        vertice = []
        for cg in polys.split('\t'):
            vertice.extend(float(val) for val in re.sub("{|}", "", cg).strip().split(","))
        vertices.append(vertice)
    return vertices, header


def truncate_fracture(filename, infe, supb):
    """Clip user polygons to a box, rewrite them, and return equivalent radii.

    Polygons are read with ``readuserEllByCoord``. Any polygon flagged by
    ``truncation_judge`` is shrunk with ``domainTruncation`` about the mean of
    its first four vertices. All polygons are written to
    'define_4_user_ellipses_3.dat' (via ``write_string``) under the original
    three-line header.

    Args:
        filename: polygon file to read.
        infe, supb: lower and upper corners of the domain box.

    Returns:
        For each polygon, ``sqrt(A / pi)``. ``A`` is the area of the
        parallelogram spanned by edges v0->v1 and v0->v3, so this is the
        radius of a disc with the same area.
    """
    polygons, header = readuserEllByCoord(filename)  # read original fracture for truncation
    radii, rows = [], []
    for poly in polygons:
        needs_cut = truncation_judge(poly, infe, supb)
        centre = [(poly[k] + poly[k + 3] + poly[k + 6] + poly[k + 9]) / 4 for k in range(3)]
        if needs_cut == 1:
            poly = domainTruncation(poly, infe, supb, centre)
        edge1 = [poly[k + 3] - poly[k] for k in range(3)]
        edge2 = [poly[k + 9] - poly[k] for k in range(3)]
        # |e1 x e2|^2 via Lagrange's identity
        cross_sq = np.dot(edge1, edge1) * np.dot(edge2, edge2) - np.dot(edge1, edge2) ** 2
        radii.append(np.sqrt(np.sqrt(cross_sq) / math.pi))
        triples = [poly[3 * v: 3 * v + 3] for v in range(len(poly) // 3)]
        rows.append('\t'.join('{' + f'{px},{py},{pz}' + '}' for px, py, pz in triples))
    write_string('define_4_user_ellipses_3.dat', 'w', header + ''.join(row + '\n' for row in rows))
    return radii


if __name__ == '__main__':
    # Generate the DFN; mesh it only if every fracture generated and none of the
    # first (non-well) fractures is isolated from the others.
    gf = os.getcwd()
    define_paths()
    DFN = create_dfn()
    DFN.make_working_directory()
    DFN.check_input()
    k = DFN.create_network()  # fractures parameters after FRAM; k=True means success
    well_allow = os.path.isfile(f'{gf}/well_allow.txt')
    with open(f'{gf}/prior_fracs.txt', 'r') as pri, open(f'{gf}/source_sink.txt', 'r') as ssri:
        prior_fracs, so_sis = pri.readlines(), ssri.readlines()
    if k:  # k=True means fracture generation's success
        with open(DFN.jobname + '/normal_vectors.dat') as oc:
            novs = oc.readlines()  # normal vectors
        num_po = len(novs) - len(so_sis)
        if len(novs) == len(prior_fracs):
            with open(DFN.jobname + '/connectivity.dat') as frc:
                sfgt = frc.readlines()[:num_po]  # eliminate extra fractures for locating wells
            for fra in sfgt:
                # split() tolerates repeated spaces/tabs that split(' ') turned into int('') errors
                neighbours = [int(tok) for tok in fra.split()]
                if not any(t <= num_po for t in neighbours):  # check if fracture is isolated
                    break
            else:
                DFN.mesh_network(wells=well_allow, slope=2)
    os.chdir('../')  # assumes make_working_directory changed into the job directory — TODO confirm
    

